{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "introduction"
      },
      "source": [
        "# Fine-tuning Gemma with Torch XLA and Hugging Face TRL\n",
        "\n",
        "Welcome to this step-by-step guide on fine-tuning the [Gemma](https://huggingface.co/google/gemma-2b) using [Torch XLA](https://github.com/pytorch/xla).\n",
        "\n",
        "\n",
        "[**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.\n",
        "\n",
        "[**Torch XLA**](https://pytorch.org/xla/) enables you to leverage the computational power of TPUs (Tensor Processing Units) for efficient training of deep learning models. By interfacing PyTorch with the [XLA (Accelerated Linear Algebra)](https://openxla.org/xla) compiler, Torch XLA translates PyTorch operations into XLA operations that can be executed on TPUs. This means you can write your models in PyTorch as usual, and Torch XLA handles the underlying computations to run them efficiently on TPUs.\n",
        "\n",
        "[**Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) is a framework developed by Hugging Face to fine-tune and align both transformer language and diffusion models using methods such as Supervised Fine-Tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), Direct Preference Optimization (DPO), and others.\n",
        "\n",
        "Integrating PyTorch with XLA allows developers to run PyTorch code on TPUs with minimal changes to their existing codebase. This seamless integration provides the performance benefits of TPUs while maintaining the flexibility and ease of use of the PyTorch framework.\n",
        "\n",
        "By the end of this notebook, you will learn:\n",
        "\n",
        "- About Torch XLA\n",
        "- How to peform **Parameter-Efficient Fine-Tuning (PEFT)** with the **Low-Rank Adaptation (LoRA)** on [Gemma 2 2B](https://huggingface.co/google/gemma-2-2b) using Hugging Face's **TRL** framework, **Torch XLA** and TPUs.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_Torch_XLA.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>\n",
        "<br><br>\n",
        "\n",
        "[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/notebooks/welcome?src=https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_Torch_XLA.ipynb)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L0Srgc7OGVJj"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Selecting the Runtime Environment\n",
        "\n",
        "To start, you can choose either **Google Colab** or **Kaggle** as your platform. Select one, and proceed from there.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "\n",
        "  1. Click **Open in Colab**.\n",
        "  2. In the menu, go to **Runtime** > **Change runtime type**.\n",
        "  3. Under **Hardware accelerator**, select **TPU**.\n",
        "  4. Ensure that the **TPU type** is set to **TPU v2-8**.\n",
        "\n",
        "- #### **Kaggle** <img src=\"https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png\" alt=\"Kaggle\" width=\"40\"/>\n",
        "\n",
        "  1. Click **Open in Kaggle**.\n",
        "  2. Click on **Settings** in the right sidebar.\n",
        "  3. Under **Accelerator**, select **TPUs**.\n",
        "    - Note: Kaggle currently provides **TPU v3-8**.\n",
        "  4. Save the settings, and the notebook will restart with TPU support.\n",
        "\n",
        "\n",
        "### Gemma using Hugging Face\n",
        "\n",
        "Before diving into the tutorial, let's set up Gemma:\n",
        "\n",
        "1. **Create a Hugging Face Account**: If you don't have one, you can sign up for a free account [here](https://huggingface.com/join).\n",
        "2. **Access the Gemma Model**: Visit the [Gemma model page](https://huggingface.com/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage conditions.\n",
        "3. **Generate a Hugging Face Token**: Go to your Hugging Face [settings page](https://huggingface.com/settings/tokens) and generate a new access token (preferably with `write` permissions). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "configure-credentials"
      },
      "source": [
        "### Configure Your Credentials\n",
        "\n",
        "To access private models and datasets, you need to log in to the Hugging Face (HF) ecosystem.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "  If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets manager:\n",
        "  1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "  2. **Add Hugging Face Token**:\n",
        "    - Create a new secret with the **name** `HF_TOKEN`.\n",
        "    - Copy/paste your token key into the **Value** input box of `HF_TOKEN`.\n",
        "    - **Toggle** the button on the left to allow notebook access to the secret\n",
        "\n",
        "- #### **Kaggle** <img src=\"https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png\" alt=\"Kaggle\" width=\"40\"/>\n",
        "  To securely use your Hugging Face token (`HF_TOKEN`) in this notebook, you'll need to add it as a secret in your Kaggle environment:  \n",
        "  1. Open your Kaggle notebook and locate the **Addons** menu at the top in your notebook interface.\n",
        "  2. Click on **Secrets** to manage your environment secrets.  \n",
        "  <img src=\"https://i.imgur.com/vxrtJuM.png\" alt=\"The Secrets option is found at the top.\" width=50%>\n",
        "  3. **Add Hugging Face Token**:\n",
        "      - Click on the **Add secret** button.\n",
        "      - In the **Label** field, enter `HF_TOKEN`.  \n",
        "      - In the **Value** field, paste your Hugging Face token.\n",
        "      - Click **Save** to add the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mBPIzOqnGmt-"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "if 'google.colab' in sys.modules:\n",
        "    # Running on Colab\n",
        "    from google.colab import userdata\n",
        "    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
        "elif os.path.exists('/kaggle/working'):\n",
        "    # Running on Kaggle\n",
        "    from kaggle_secrets import UserSecretsClient\n",
        "    user_secrets = UserSecretsClient()\n",
        "    os.environ['HF_TOKEN'] = user_secrets.get_secret(\"HF_TOKEN\")\n",
        "else:\n",
        "    # Not running on Colab or Kaggle\n",
        "    raise EnvironmentError('This notebook is designed to run on Google Colab or Kaggle.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ys5-2RMfUroM"
      },
      "source": [
        "This code retrieves your secrets and sets them as environment variables, which you will use later in the tutorial."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "setting-up-environment"
      },
      "source": [
        "### Setting Up the Environment\n",
        "\n",
        "Next, you'll set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model on a TPU VM using Torch XLA.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "setup-code"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found existing installation: tensorflow 2.15.0\n",
            "Uninstalling tensorflow-2.15.0:\n",
            "  Successfully uninstalled tensorflow-2.15.0\n",
            "Found existing installation: tf_keras 2.15.1\n",
            "Uninstalling tf_keras-2.15.1:\n",
            "  Successfully uninstalled tf_keras-2.15.1\n",
            "Collecting tensorflow==2.18.0\n",
            "  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)\n",
            "Collecting tf-keras==2.18.0\n",
            "  Downloading tf_keras-2.18.0-py3-none-any.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.4.0)\n",
            "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.6.3)\n",
            "Requirement already satisfied: flatbuffers>=24.3.25 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (24.3.25)\n",
            "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (0.6.0)\n",
            "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (0.2.0)\n",
            "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (18.1.1)\n",
            "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (3.4.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (24.2)\n",
            "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (4.25.5)\n",
            "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (2.32.3)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (75.1.0)\n",
            "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.16.0)\n",
            "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (2.5.0)\n",
            "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (4.12.2)\n",
            "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.14.1)\n",
            "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.68.1)\n",
            "Collecting tensorboard<2.19,>=2.18 (from tensorflow==2.18.0)\n",
            "  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)\n",
            "Collecting keras>=3.5.0 (from tensorflow==2.18.0)\n",
            "  Downloading keras-3.7.0-py3-none-any.whl.metadata (5.8 kB)\n",
            "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.26.4)\n",
            "Requirement already satisfied: h5py>=3.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (3.12.1)\n",
            "Collecting ml-dtypes<0.5.0,>=0.4.0 (from tensorflow==2.18.0)\n",
            "  Downloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n",
            "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (0.37.1)\n",
            "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow==2.18.0) (0.45.1)\n",
            "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.5.0->tensorflow==2.18.0) (13.9.4)\n",
            "Collecting namex (from keras>=3.5.0->tensorflow==2.18.0)\n",
            "  Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes)\n",
            "Collecting optree (from keras>=3.5.0->tensorflow==2.18.0)\n",
            "  Downloading optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (47 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m47.8/47.8 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (3.4.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (2.2.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (2024.8.30)\n",
            "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0) (3.7)\n",
            "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0) (0.7.2)\n",
            "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0) (3.1.3)\n",
            "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.19,>=2.18->tensorflow==2.18.0) (3.0.2)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.5.0->tensorflow==2.18.0) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.5.0->tensorflow==2.18.0) (2.18.0)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.5.0->tensorflow==2.18.0) (0.1.2)\n",
            "Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m615.3/615.3 MB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tf_keras-2.18.0-py3-none-any.whl (1.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m311.1 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading keras-3.7.0-py3-none-any.whl (1.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m57.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m73.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tensorboard-2.18.0-py3-none-any.whl (5.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m101.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading namex-0.0.8-py3-none-any.whl (5.8 kB)\n",
            "Downloading optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (381 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m381.3/381.3 kB\u001b[0m \u001b[31m24.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: namex, optree, ml-dtypes, tensorboard, keras, tensorflow, tf-keras\n",
            "  Attempting uninstall: ml-dtypes\n",
            "    Found existing installation: ml-dtypes 0.2.0\n",
            "    Uninstalling ml-dtypes-0.2.0:\n",
            "      Successfully uninstalled ml-dtypes-0.2.0\n",
            "  Attempting uninstall: tensorboard\n",
            "    Found existing installation: tensorboard 2.15.2\n",
            "    Uninstalling tensorboard-2.15.2:\n",
            "      Successfully uninstalled tensorboard-2.15.2\n",
            "  Attempting uninstall: keras\n",
            "    Found existing installation: keras 2.15.0\n",
            "    Uninstalling keras-2.15.0:\n",
            "      Successfully uninstalled keras-2.15.0\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "tensorflow-text 2.15.0 requires tensorflow<2.16,>=2.15.0; platform_machine != \"arm64\" or platform_system != \"Darwin\", but you have tensorflow 2.18.0 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed keras-3.7.0 ml-dtypes-0.4.1 namex-0.0.8 optree-0.13.1 tensorboard-2.18.0 tensorflow-2.18.0 tf-keras-2.18.0\n",
            "Found existing installation: tensorflow 2.18.0\n",
            "Uninstalling tensorflow-2.18.0:\n",
            "  Successfully uninstalled tensorflow-2.18.0\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m230.0/230.0 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.1/44.1 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.0/10.0 MB\u001b[0m \u001b[31m75.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m69.2/69.2 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m13.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m51.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m15.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m241.9/241.9 kB\u001b[0m \u001b[31m20.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.6/124.6 kB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m205.1/205.1 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m319.7/319.7 kB\u001b[0m \u001b[31m25.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.2/310.2 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.7/320.7 kB\u001b[0m \u001b[31m21.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m324.3/324.3 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tpu-info in /usr/local/lib/python3.10/dist-packages (0.2.0)\n",
            "Requirement already satisfied: grpcio>=1.65.5 in /usr/local/lib/python3.10/dist-packages (from tpu-info) (1.68.1)\n",
            "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from tpu-info) (4.25.5)\n",
            "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from tpu-info) (13.9.4)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->tpu-info) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->tpu-info) (2.18.0)\n",
            "Requirement already satisfied: typing-extensions<5.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from rich->tpu-info) (4.12.2)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->tpu-info) (0.1.2)\n"
          ]
        }
      ],
      "source": [
        "# Uninstalling any existing TensorFlow installations and then install the CPU-only version to avoid conflicts while using the TPU.\n",
        "!pip uninstall -y tensorflow tf-keras\n",
        "!pip install tensorflow==2.18.0 tf-keras==2.18.0\n",
        "\n",
        "!pip uninstall tensorflow -y\n",
        "!pip install tensorflow-cpu==2.18.0 -q\n",
        "\n",
        "# Install the appropriate Hugging Face libraries to ensure compatibility with the Gemma model and PEFT.\n",
        "!pip install transformers==4.46.1 -U -q\n",
        "!pip install datasets==3.1.0 -U -q\n",
        "!pip install trl==0.12.0 peft==0.13.2 -U -q\n",
        "!pip install accelerate==0.34.0 -U -q\n",
        "\n",
        "# Install PyTorch and Torch XLA with versions compatible with the TPU runtime, ensuring efficient TPU utilization.\n",
        "!pip install -qq torch~=2.5.0 --index-url https://download.pytorch.org/whl/cpu\n",
        "!pip install -qq torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html\n",
        "\n",
        "# Install the `tpu-info` package to display TPU-related information\n",
        "!pip install tpu-info"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "note"
      },
      "source": [
        "**Note**: Ensure that your PyTorch and Torch XLA versions are compatible with the TPU you're using."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1M7YM_apWfOk"
      },
      "source": [
        "### Verify TPU Setup\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yjo5_5xfVYNm"
      },
      "source": [
        "You run `!tpu-info` to verify the TPU has been properly initialized."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "skAvSa5KF65m"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[3mTPU Chips                                     \u001b[0m\n",
            "┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━┓\n",
            "┃\u001b[1m \u001b[0m\u001b[1mChip       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mType       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDevices\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPID \u001b[0m\u001b[1m \u001b[0m┃\n",
            "┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━┩\n",
            "│ /dev/accel0 │ TPU v2 chip │ 2       │ None │\n",
            "│ /dev/accel1 │ TPU v2 chip │ 2       │ None │\n",
            "│ /dev/accel2 │ TPU v2 chip │ 2       │ None │\n",
            "│ /dev/accel3 │ TPU v2 chip │ 2       │ None │\n",
            "└─────────────┴─────────────┴─────────┴──────┘\n",
            "Libtpu metrics unavailable. Is there a framework using the TPU? See https://github.com/google/cloud-accelerator-diagnostics/tree/main/tpu_info for more information\n"
          ]
        }
      ],
      "source": [
        "!tpu-info"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o0zsGkzqWien"
      },
      "source": [
        "If everything is set up correctly, you should see the TPU details printed out.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "importing-libraries"
      },
      "source": [
        "## Import the libraries\n",
        "\n",
        "Now, import all the necessary libraries required for fine-tuning.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "import-code"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "PyTorch version: 2.5.1+cpu\n",
            "Torch XLA version: 2.5.1+libtpu\n"
          ]
        }
      ],
      "source": [
        "import pandas as pd\n",
        "\n",
        "import torch\n",
        "print(f\"PyTorch version: {torch.__version__}\")\n",
        "\n",
        "import torch_xla\n",
        "print(f\"Torch XLA version: {torch_xla.__version__}\")\n",
        "\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.runtime as xr\n",
        "\n",
        "from transformers import (\n",
        "    AutoTokenizer,\n",
        "    AutoModelForCausalLM,\n",
        ")\n",
        "from trl import SFTTrainer, SFTConfig\n",
        "from peft import LoraConfig, PeftModel\n",
        "\n",
        "from datasets import load_dataset\n",
        "\n",
        "# Enable Single Program Multiple Data (SPMD) mode,\n",
        "# which allows for parallel execution across multiple TPU cores\n",
        "xr.use_spmd()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GSDScQniWo4L"
      },
      "source": [
        "This setup ensures that your environment is correctly configured to use TPUs with PyTorch.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "finetuning-peft"
      },
      "source": [
        "## Fine-tune using PEFT and LoRA\n",
        "\n",
        "Traditional fine-tuning of large language models (LLMs) like Gemma requires adjusting billions of parameters, making it resource-intensive. This process demands significant computational power and time, which can be impractical for many use cases. That's where Parameter-Efficient Fine-Tuning (PEFT) techniques come in.\n",
        "\n",
        "### Parameter-Efficient Fine-Tuning (PEFT)\n",
        "\n",
        "PEFT allows you to adapt large models to specific tasks by updating only a small portion of their parameters. Instead of retraining the entire model, PEFT adds lightweight layers or adapters. Most of the pre-trained weights remain frozen. This approach greatly reduces the computational requirements and the amount of data needed for fine-tuning, making it feasible to fine-tune large models even on modest hardware.\n",
        "\n",
        "### Low-Rank Adaptation (LoRA)\n",
        "\n",
        "Among these techniques, one effective option is Low-Rank Adaptation (LoRA). LoRA introduces small, trainable matrices into the model's architecture, specifically targeting the attention layers of Transformer models. Instead of updating the full weight matrices, LoRA adds rank-decomposed matrices, making adaptation more efficient.\n",
        "\n",
        "#### Key Advantages of LoRA\n",
        "\n",
        "- **Efficiency**: LoRA significantly reduces the number of trainable parameters by using low-rank adaptations, making the fine-tuning process much more efficient.\n",
        "- **Memory Savings**: Since only the additional low-rank matrices are updated, GPU/TPU memory requirements are considerably lower.\n",
        "- **Modularity**: LoRA adapters can be easily merged with the original model or kept separate, offering flexibility for deployment.\n",
        "\n",
        "In the next section, you'll explore how to implement PEFT with LoRA to fine-tune Gemma using Torch XLA on TPUs and perform the following steps:\n",
        "\n",
        "- Load a dataset\n",
        "- Configure the training parameters\n",
        "- Load the Gemma model and tokenizer\n",
        "- Fine-tune the model on TPUs using **TRL**'s `SFTTrainer` class"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "create-dataset"
      },
      "source": [
        "### Load a dataset\n",
        "\n",
        "For this guide, you'll use an existing dataset from Hugging Face. You can replace it with your own dataset if you prefer.\n",
        "\n",
        "The dataset chosen for this guide is [**hieunguyenminh/roleplay**](https://huggingface.com/datasets/hieunguyenminh/roleplay), which embodies a wide range of original characters, each with a unique persona. It includes fictional characters, complete with their own backgrounds, core traits, relationships, goals, and distinct speaking styles.\n",
        "\n",
        "**Credits:** **https://huggingface.com/hieunguyenminh**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j01SyZuIuU6E"
      },
      "source": [
        "You specify the dataset name and use the `load_dataset` function from the Hugging Face `datasets` library to load the training split of the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dataset-code"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d8c10ea39e384fca945ac3742e8a3d7e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/3.58k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "791a3d32cb8641fb8430563ad8c5c165",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "train-00000-of-00001.parquet:   0%|          | 0.00/2.15M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "faeabc9bd3c34ff887e65c139a877a34",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/5755 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['name', 'description', 'text'],\n",
              "    num_rows: 5755\n",
              "})"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset_name = \"hieunguyenminh/roleplay\"\n",
        "dataset = load_dataset(dataset_name, split=\"train\")\n",
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hSVA5V20AyRS"
      },
      "source": [
        "Let's look at a few samples to understand the data.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yFTDflHrA1KT"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'<|system|>Michael Jordan, also known as \"MJ\" or \"His Airness,\" is a basketball legend renowned for his unparalleled competitive spirit and extraordinary athletic prowess. Born on February 17, 1963, in Brooklyn, New York, he grew up in Wilmington, North Carolina. Jordan\\'s illustrious career in the NBA, primarily with the Chicago Bulls, saw him secure six championship wins and earn numerous accolades, including five regular-season MVP awards. His impact on the global sports landscape is immeasurable, as he transcended the game of basketball to become a cultural icon. His Air Jordan sneakers, in collaboration with Nike, revolutionized the concept of athlete endorsements and remain highly coveted to this day. Jordan\\'s relentless pursuit of excellence and his ability to perform under pressure have solidified his legacy as one of the greatest athletes of all time.</s>\\n<|user|>What was Michael Jordan\\'s mindset during high-pressure moments in games?</s>\\n<|assistant|>Michael Jordan\\'s mindset during high-pressure moments was truly remarkable. He possessed an unwavering self-belief and an insatiable desire to succeed. His famous quote, \"I can accept failure, everyone fails at something. But I can\\'t accept not trying,\" encapsulates his mentality. When the game was on the line, he thrived on taking the decisive shot, exuding confidence and determination. His focus and competitive fire propelled him to elevate his performance when it mattered most, earning him the reputation of a clutch performer.</s>\\n<|user|>How did Michael Jordan\\'s upbringing influence his approach to the game of basketball?</s>\\n<|assistant|>Michael Jordan\\'s upbringing played a pivotal role in shaping his approach to the game of basketball. Growing up, he faced adversity and developed a relentless work ethic. His older brother, Larry, often provided tough competition, fueling Jordan\\'s competitive nature from a young age. Additionally, being cut from his high school varsity team drove him to prove himself and refine his skills, ultimately leading to his unwavering dedication to the sport. These formative experiences instilled in him a tenacious drive to succeed and a refusal to settle for anything less than excellence.</s>\\n<|user|>What were Michael Jordan\\'s key contributions to the global influence of basketball?</s>\\n<|assistant|>Michael Jordan\\'s impact on the global influence of basketball is unparalleled. His electrifying style of play and captivating performances on the court captivated audiences worldwide, popularizing the sport on a global scale. His partnership with Nike resulted in the iconic Air Jordan brand, transcending the realm of athletics to become a cultural phenomenon. Through his charisma and unparalleled skill, he elevated the status of basketball, inspiring generations of fans and aspiring athletes. Jordan\\'s influence extended beyond the court, shaping the way basketball is perceived and celebrated across diverse cultures and communities.</s>\\n<|user|>What were some of Michael Jordan\\'s most memorable career moments?</s>\\n<|assistant|>Michael Jordan\\'s career is adorned with numerous unforgettable moments that have left an indelible mark on the history of basketball. His game-winning shot in the 1982 NCAA Championship game for the University of North Carolina showcased his early potential for clutch performances. His iconic \"Flu Game\" during the 1997 NBA Finals, where he battled illness to deliver a stellar performance, stands as a testament to his resilience and determination. Furthermore, his return from retirement to lead the Chicago Bulls to three additional championships in the late 1990s solidified his status as a transcendent figure in sports. Each of these moments encapsulates his ability to rise to the occasion and etch his name in the annals of basketball lore.</s>'"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset[10]['text']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yOZoaR67Az3I"
      },
      "outputs": [],
      "source": [
        "if 'google.colab' in sys.modules:\n",
        "    from google.colab import data_table\n",
        "\n",
        "    # Enable interactive DataFrame display\n",
        "    data_table.enable_dataframe_formatter()\n",
        "\n",
        "# Convert the 'train' split to a Pandas DataFrame\n",
        "df = pd.DataFrame(dataset)\n",
        "\n",
        "# Select the 'text' column and exclude the rest\n",
        "df_text = df[['text']]\n",
        "df_text.head(5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lz3MMqzKGsOn"
      },
      "source": [
        "First, let's split the dataset into training and validation sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TNj_8BdoGrn2"
      },
      "outputs": [],
      "source": [
        "# The first 80% of `train` for training\n",
        "train_dataset = load_dataset(dataset_name, split='train[:80%]')\n",
        "\n",
        "# The last 20% of `train` for evaluation\n",
        "valid_dataset = load_dataset(dataset_name, split='train[-20%:]')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DjqJmj-b8vB5"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['name', 'description', 'text'],\n",
              "    num_rows: 4604\n",
              "})"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B18kyvz9N3gT"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['name', 'description', 'text'],\n",
              "    num_rows: 1151\n",
              "})"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "valid_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PgmFU-bS_cXQ"
      },
      "source": [
        "Preprocess the dataset for [Gemma instruction tuning](https://ai.google.dev/gemma/docs/formatting).\n",
        "\n",
        "**Note**: Gemma doesn't support the `system` role in a conversation. Instead, you'll be replacing this with the `user` role."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3twprkFO_WV5"
      },
      "outputs": [],
      "source": [
        "def convert_to_gemma_format(text):\n",
        "  # Replace role tokens with Gemma's instruction tuning format\n",
        "    text = text.replace(\"<|system|>\", \"<start_of_turn>user\\n\")\n",
        "    text = text.replace(\"<|assistant|>\", \"<start_of_turn>model\\n\")\n",
        "    text = text.replace(\"<|user|>\", \"<start_of_turn>user\\n\")\n",
        "\n",
        "    # Replace end-of-sequence tokens with <end_of_turn>\n",
        "    text = text.replace(\"</s>\", \"<end_of_turn>\\n\")\n",
        "\n",
        "    # Clean up extra newlines if necessary\n",
        "    text = text.strip()\n",
        "    return text\n",
        "\n",
        "def preprocess_function(example):\n",
        "    text = example[\"text\"]\n",
        "    text = convert_to_gemma_format(text)\n",
        "\n",
        "    return {\"text\": text}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vHB5S-bW_aef"
      },
      "outputs": [],
      "source": [
        "# Apply the preprocessing\n",
        "train_dataset = train_dataset.map(preprocess_function,\n",
        "                                  remove_columns=list(train_dataset.features))\n",
        "train_dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mLkvJi8DArIs"
      },
      "outputs": [],
      "source": [
        "valid_dataset = valid_dataset.map(preprocess_function,\n",
        "                                  remove_columns=list(valid_dataset.features))\n",
        "valid_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "define-parameters"
      },
      "source": [
        "### Training Configuration\n",
        "\n",
        "Now you need to define all the hyperparameters and configurations needed for the fine-tuning process and this includes defining the following:\n",
        "\n",
        "- The base model and new model names\n",
        "- LoRA Configuration\n",
        "- Training Arguments\n",
        "- SFT Parameters\n",
        "- Misc. Parameters\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wPJaoQqMJMIp"
      },
      "source": [
        "You start by specifying the base model (`google/gemma-2b`) and the directory where the fine-tuned model will be saved (`gemma-ft`).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fz5T6Wr1IbM-"
      },
      "outputs": [],
      "source": [
        "# Define model names\n",
        "model_name = \"google/gemma-2-2b-it\"\n",
        "new_model = \"gemma-ft\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NdMLPDvdJLTF"
      },
      "source": [
        "LoRA (Low-Rank Adaptation) allows for efficient fine-tuning by adapting only a subset of model parameters.\n",
        "\n",
        "Here, you set the following parameters:\n",
        "- `lora_r` to 64, which controls the rank of the adaptation matrices,\n",
        "- `lora_alpha` to 32 for scaling\n",
        "- `lora_dropout` to 0.1 to prevent overfitting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6fu4Efr7UPFK"
      },
      "outputs": [],
      "source": [
        "# LoRA attention dimension\n",
        "lora_r = 64 # @param {\"type\":\"slider\",\"min\":0,\"max\":64,\"step\":2}\n",
        "# Alpha parameter for LoRA scaling\n",
        "lora_alpha = 32 # @param {\"type\":\"slider\",\"min\":0,\"max\":64,\"step\":2}\n",
        "# Dropout probability for LoRA layers\n",
        "lora_dropout = 0.1 # @param {\"type\":\"slider\",\"min\":0,\"max\":1,\"step\":0.01}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AR8OCRL2JIuA"
      },
      "source": [
        "Set up the training arguments that define how the model will be trained.\n",
        "\n",
        "Here, you'll define the **output directory**, **number of training epochs**, and **batch sizes** for training and evaluation. You enable **gradient checkpointing** to save memory and set `max_grad_norm` for gradient clipping to stabilize training. The **learning rate**, **optimizer**, and **learning rate scheduler** are configured to optimize the training process. The `max_steps` is set to **-1** to let the number of epochs control training duration.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HEmgRATcIfAR"
      },
      "outputs": [],
      "source": [
        "# Output directory where the model predictions and checkpoints will be stored\n",
        "output_dir = \"./results\" # @param {\"type\":\"string\"}\n",
        "# Number of training epochs\n",
        "num_train_epochs = 5 # @param {\"type\":\"slider\",\"min\":1,\"max\":20,\"step\":2}\n",
        "# Batch size per TPU core for training\n",
        "per_device_train_batch_size = 32 # @param {\"type\":\"slider\",\"min\":1,\"max\":64,\"step\":1}\n",
        "# Batch size per TPU core for evaluation\n",
        "per_device_eval_batch_size = 32 # @param {\"type\":\"slider\",\"min\":1,\"max\":64,\"step\":1}\n",
        "# Number of update steps to accumulate the gradients for\n",
        "gradient_accumulation_steps = 1 # @param {\"type\":\"slider\",\"min\":0,\"max\":16,\"step\":2}\n",
        "# Maximum gradient normal (gradient clipping)\n",
        "max_grad_norm = 0.3 # @param {\"type\":\"slider\",\"min\":0,\"max\":1,\"step\":0.01}\n",
        "# Initial learning rate (adafactor optimizer)\n",
        "learning_rate = 0.0001 # @param {\"type\":\"slider\",\"min\":0.00001,\"max\":0.0005,\"step\":0.00001}\n",
        "# Optimizer to use\n",
        "optim = \"adafactor\" # adafactor, adamw_torch_fused\n",
        "# Learning rate schedule (constant a bit better than cosine)\n",
        "lr_scheduler_type = \"constant\"\n",
        "# Number of training steps (overrides num_train_epochs)\n",
        "max_steps = -1\n",
        "# Ratio of steps for a linear warmup (from 0 to learning rate)\n",
        "warmup_ratio = 0.03 # @param {\"type\":\"slider\",\"min\":0,\"max\":0.1,\"step\":0.01}\n",
        "# Enable bfloat16 precision\n",
        "bf16 = True\n",
        "# Log every X updates steps\n",
        "logging_steps = 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n2fdj68nJG8D"
      },
      "source": [
        "In the SFT parameters, `max_seq_length` is set to 512 to define the maximum token length for inputs, and `packing` is enabled to pack multiple shorter sequences into one input for efficiency."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YTkqcAyOI3ZY"
      },
      "outputs": [],
      "source": [
        "# Maximum sequence length to use\n",
        "max_seq_length = 512 # @param {\"type\":\"slider\",\"min\":32,\"max\":1024,\"step\":2}\n",
        "# Pack multiple short examples in the same input sequence to increase efficiency\n",
        "packing = True"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VdyuPcwXdFIH"
      },
      "source": [
        "### Under the hood: PyTorch and XLA\n",
        "\n",
        "PyTorch programs define computation graphs dynamically using its `autograd` system. The TPU does not directly execute Python code; instead, it runs the computation graph defined by your PyTorch program. Behind the scenes, a compiler called **XLA (Accelerated Linear Algebra compiler)** transforms the PyTorch computation graph into TPU machine code. This compiler also performs numerous advanced optimizations on your code and memory layout. The compilation occurs automatically as tasks are sent to the TPU, and you do not need to include XLA in your build chain explicitly.\n",
        "\n",
        "<img src=\"https://storage.googleapis.com/gweb-cloudblog-publish/images/1_PyTorchXLA_stack_diagram.max-800x800.png\" alt=\"PyTorch and XLA 2.3 from https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-xla-2-3\" width=50%>\n",
        "\n",
        "The combination of **PyTorch** and **XLA** offers several key advantages:\n",
        "\n",
        "1. **Seamless Performance Enhancement:** Maintain PyTorch's intuitive and pythonic workflow while effortlessly achieving significant performance gains through the XLA compiler. This integration allows you to optimize your models without altering your familiar coding practices.\n",
        "\n",
        "2. **Comprehensive Ecosystem Access:** Leverage PyTorch's extensive ecosystem, including a wide range of tools, pretrained models, and a vibrant community. This access enables you to accelerate development, utilize state-of-the-art resources, and benefit from collective expertise.\n",
        "\n",
        "Harnessing these advantages, you can efficiently fine-tune your custom Gemma model using TPUs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "load-model-tokenizer"
      },
      "source": [
        "### Fine-tune Gemma using TPUs\n",
        "\n",
        "The training leverages PyTorch, XLA and TPUs for efficient computation and uses LoRA for parameter-efficient fine-tuning, which reduces the number of trainable parameters by adapting only specific layers.\n",
        "\n",
        "Here, you'll be setting up the following:\n",
        "\n",
        "- The Gemma **base model** and **tokenizer**\n",
        "- The **LoRA (Low-Rank Adaptation)** configuration for **PEFT (Parameter-Efficient Fine-Tuning)**\n",
        "- The [**FSDP**](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#how-fsdp-works) configuration for efficient TPU training\n",
        "- The **Hugging Face `SFTTrainer` instance** using the training and SFT parameters"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oeQJ5ou8bOX7"
      },
      "source": [
        "First, load the Gemma 2B pre-trained model weights using `AutoModelForCausalLM`, while setting `torch_dtype` to `torch.bfloat16` for optimal performance on TPUs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SGeNI91SbeIZ"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8a957d6e3b6a40b4b726802a0e65fa1c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "45bc741b4ca648a0affeac5d577702be",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "be93c20848e74cc88952be6ffcc2cbeb",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "48e894226dca40f0b7f02c0fc9568e4c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "22b1a4ce64a94bdcb4dd229e79ea1054",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "95f7b9ca4f8f4718b4b8e15a8fcf94da",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c3069032ea404f0ab21e0f6a68823d93",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Load the Gemma pretrained model\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    torch_dtype=torch.bfloat16\n",
        ")\n",
        "\n",
        "# You must disable the cache to prevent issues during training\n",
        "model.config.use_cache = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IUJgEbmobai7"
      },
      "source": [
        "Next, you load the Gemma tokenizer using `AutoTokenizer` from Hugging Face. You adjust the tokenizer's padding side (and token if applicable) here to ensure compatibility during training.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ICqAzUzUbZc7"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0575974ceaca4881be2c89edc7eff5a8",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b16e844eb0ee41649ed48f9bbe09e612",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9547bea22a894b3c92ad9c03c4d5311d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ce92ca605ce941ebb573bcaae53e706a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Load the Gemma tokenizer\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# You adjust the tokenizer's padding side to ensure compatibility during TPU\n",
        "# training.\n",
        "tokenizer.padding_side = 'right'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wp3Ovv3ROPsQ"
      },
      "source": [
        "Now, you've loaded the base Gemma model and tokenizer, and set up the configurations for fine-tuning. Let's focus on initializing the **LoRA** config. Since you're using LoRA, the PEFT library provides a handy [LoraConfig](https://huggingface.com/docs/peft/main/en/package_reference/lora#peft.LoraConfig) which defines on which layers of the base model to apply the adapters. One typically applies LoRA on the linear projection matrices of the attention layers of a **Transformer**. You then provide this configuration to the `SFTTrainer` class in the tutorial later.\n",
        "\n",
        "The `LoraConfig` is initialized with the previously defined LoRA parameters, specifying the target modules (`k_proj` and `v_proj`) in the model to apply LoRA adaptations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mvwCjIDljsFn"
      },
      "outputs": [],
      "source": [
        "# Load LoRA configuration\n",
        "peft_config = LoraConfig(\n",
        "    lora_alpha=lora_alpha,\n",
        "    lora_dropout=lora_dropout,\n",
        "    r=lora_r,\n",
        "    bias=\"none\",\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    target_modules=[\n",
        "        \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\"gate_proj\", \"up_proj\"\n",
        "    ]\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pz3GwV75juTW"
      },
      "source": [
        "The **Fully Sharded Data Parallel (FSDP)** configuration is set up in `fsdp_config`, enabling [**full model sharding**](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy) and [**gradient checkpointing**](https://huggingface.co/docs/transformers/v4.19.4/en/performance#gradient-checkpointing) for memory efficiency on TPUs, and specifying that gradient checkpointing should be enabled with `xla_fsdp_grad_ckpt`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nXUHhhpjbxaP"
      },
      "outputs": [],
      "source": [
        "# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.\n",
        "fsdp_config = {\n",
        "    \"fsdp_transformer_layer_cls_to_wrap\": [\n",
        "        \"Gemma2DecoderLayer\"\n",
        "    ],\n",
        "    \"xla\": True,\n",
        "    \"xla_fsdp_v2\": True,\n",
        "    \"xla_fsdp_grad_ckpt\": True\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Efxbp6m5bvZq"
      },
      "source": [
        "The `SFTConfig` is then initialized with all the training parameters defined earlier, including optimizer settings, learning rate, and logging configurations, and specifying that logs should be reported to `TensorBoard`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "load-model-code"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1559: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.\n",
            "WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.\n"
          ]
        }
      ],
      "source": [
        "# Set training parameters\n",
        "training_arguments = SFTConfig(\n",
        "    output_dir=output_dir,\n",
        "    overwrite_output_dir=True,\n",
        "    save_strategy=\"no\",\n",
        "    # Training\n",
        "    num_train_epochs=num_train_epochs,\n",
        "    # This is the global train batch size for SPMD\n",
        "    per_device_train_batch_size=per_device_train_batch_size,\n",
        "    gradient_accumulation_steps=gradient_accumulation_steps,\n",
        "    optim=optim,\n",
        "    # Required for SPMD\n",
        "    dataloader_drop_last=True,\n",
        "    fsdp=\"full_shard\",\n",
        "    fsdp_config=fsdp_config,\n",
        "    learning_rate=learning_rate,\n",
        "    bf16=bf16,\n",
        "    max_grad_norm=max_grad_norm,\n",
        "    max_steps=max_steps,\n",
        "    warmup_ratio=warmup_ratio,\n",
        "    lr_scheduler_type=lr_scheduler_type,\n",
        "    max_seq_length=max_seq_length,\n",
        "    dataset_text_field=\"text\",\n",
        "    dataset_kwargs={\n",
        "        \"add_special_tokens\": False,\n",
        "        \"append_concat_token\": False,\n",
        "    },\n",
        "    group_by_length=True,\n",
        "    packing=packing,\n",
        "    # Evaluation\n",
        "    evaluation_strategy=\"epoch\",\n",
        "    # This is the global eval batch size for SPMD\n",
        "    per_device_eval_batch_size=per_device_eval_batch_size,\n",
        "    # Logging\n",
        "    logging_steps=logging_steps,\n",
        "    report_to=\"none\",\n",
        "    seed=42\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CPvenjpscfHb"
      },
      "source": [
        "Finally, you define the [SFTTrainer](https://huggingface.com/docs/trl/sft_trainer) available in the TRL library. This class inherits from the `Trainer` class available in the Transformers library, but is specifically optimized for supervised fine-tuning (instruction tuning). It can be used to train out-of-the-box on one or more GPUs/TPUs, using [Accelerate](https://huggingface.com/docs/accelerate/index) as backend.\n",
        "\n",
        "Most notably, it supports [packing](https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-), where multiple short examples are packed in the same input sequence to increase training efficiency."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sCdJfqjJchH_"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ada44b53933e48329bc436fd9aee4b9c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split: 0 examples [00:00, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "de1877b0662c4a4cab29db0727298bcf",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split: 0 examples [00:00, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Set supervised fine-tuning parameters\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=valid_dataset,\n",
        "    peft_config=peft_config,\n",
        "    args=training_arguments,\n",
        "    tokenizer=tokenizer\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "train-model"
      },
      "source": [
        "Now, let's start the fine-tuning process by calling `trainer.train()`, which uses `SFTTrainer` to handle the training loop, including data loading, forward and backward passes, and optimizer steps, all configured according to the settings you've provided."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C1NMY9GF15dZ"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>\n",
            "  warnings.warn(\"For backward hooks to be called,\"\n",
            "/usr/local/lib/python3.10/dist-packages/torch_xla/utils/checkpoint.py:183: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
            "  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \\\n",
            "/usr/local/lib/python3.10/dist-packages/torch_xla/utils/checkpoint.py:184: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
            "  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='700' max='700' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [700/700 42:56, Epoch 5/5]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.621100</td>\n",
              "      <td>1.326517</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.464800</td>\n",
              "      <td>1.306985</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.324200</td>\n",
              "      <td>1.493107</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.253900</td>\n",
              "      <td>1.817325</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.250000</td>\n",
              "      <td>1.548943</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py:1457: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
            "  xldata.append(torch.load(xbio))\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>\n",
            "  warnings.warn(\"For backward hooks to be called,\"\n",
            "/usr/local/lib/python3.10/dist-packages/torch_xla/utils/checkpoint.py:183: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
            "  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \\\n",
            "/usr/local/lib/python3.10/dist-packages/torch_xla/utils/checkpoint.py:184: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
            "  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):\n",
            "/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py:1457: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
            "  xldata.append(torch.load(xbio))\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.\n",
            "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>\n",
            "  warnings.warn(\"For backward hooks to be called,\"\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=700, training_loss=0.49772391183035714, metrics={'train_runtime': 2675.5038, 'train_samples_per_second': 8.372, 'train_steps_per_second': 0.262, 'total_flos': 1.842971582398464e+17, 'train_loss': 0.49772391183035714, 'epoch': 5.0})"
            ]
          },
          "execution_count": 25,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0YKWVCXUc3gP"
      },
      "source": [
        "After training is complete, you save the fine-tuned model by moving it to the CPU with `trainer.model.to('cpu')` to ensure compatibility and then calling `save_pretrained(new_model)` to save the model weights and configuration files to the directory specified by `new_model` (**gemma-ft**). This allows you to reload and use the fine-tuned model later for inference or further training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X88_th2Jc5Lr"
      },
      "outputs": [],
      "source": [
        "# Remove the model weights directory if it exists\n",
        "!rm -rf gemma-ft\n",
        "\n",
        "# Save the LoRA adapter\n",
        "trainer.model.to('cpu').save_pretrained(new_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "prompt-model"
      },
      "source": [
        "## Prompt using the newly fine-tuned model\n",
        "\n",
        "\n",
        "Now that you've finally fine-tuned your custom Gemma model, let's reload the LoRA adapter weights to finally prompt using it and also verify if it's really working as intended."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kMcfyA3ed_EC"
      },
      "source": [
        "To do this, use the following steps to correctly reload the adapter weights:\n",
        "\n",
        "- Use `AutoModelForCausalLM.from_pretrained` to first load the **base Gemma model**, while setting `low_cpu_mem_usage=True` to optimize memory consumption (since you're using a TPU) and `torch_dtype=torch.bfloat16` for consistency with the fine-tuned model.\n",
        "\n",
        "- Load the **fine-tuned LoRA adapter** that you've previously saved into the base model using `PeftModel.from_pretrained`, where `new_model` is the directory containing your fine-tuned weights.\n",
        "\n",
        "- The `model.merge_and_unload` function **merges** the **LoRA adapter weights** with the **base model weights** and unloads the adapter, resulting in a standalone model ready for inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NGUY_Gw-eFh4"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "bce916b65c194938a44985cf14675c6a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Reload the fine-tuned Gemma model\n",
        "base_model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    low_cpu_mem_usage=True,\n",
        "    return_dict=True,\n",
        "    torch_dtype=torch.bfloat16\n",
        ")\n",
        "model = PeftModel.from_pretrained(base_model, new_model)\n",
        "model = model.merge_and_unload()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0IFZ5jOOeVEg"
      },
      "source": [
        "You reload the tokenizer to ensure it matches the model configuration, adjusting the padding side as before."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BL-ZGf87ewVT"
      },
      "outputs": [],
      "source": [
        "# Reload tokenizer\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "tokenizer.padding_side = 'right'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zik8ZNl9esDR"
      },
      "source": [
        "Now, test the fine-tuned model with a sample prompt by first using the tokenizer to generate the input ids, and then relying on the reloaded fine-tuned model to generate a response using `model.generate()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LLC-D0KCTdqM"
      },
      "outputs": [],
      "source": [
        "input_text = \"\"\"\\\n",
        "  <|system|>Introducing Minami \"Echo\" Ishikawa, a mysterious VR assassin known for her uncanny ability to blend seamlessly into the shadows. \\\n",
        "  Minami possesses a deep understanding of stealth techniques, allowing her to silently eliminate her targets with calculated precision. \\\n",
        "  Her cold and calculating demeanor makes her a formidable force to be reckoned with, leaving enemies shivering at the thought of facing her wrath.</s>\n",
        "  <|user|>Echo, what makes you so adept at disappearing into thin air?</s>\n",
        "  <|assistant|>\"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ly_S-4mtPDfV"
      },
      "outputs": [],
      "source": [
        "input_text = convert_to_gemma_format(input_text)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jgOfjrs-eyKr"
      },
      "outputs": [],
      "source": [
        "input_ids = tokenizer(input_text, return_tensors=\"pt\").to(\"cpu\")\n",
        "outputs = model.generate(**input_ids, max_length=512,\n",
        "                         eos_token_id=tokenizer.eos_token_id)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5n-oX0EnekMe"
      },
      "source": [
        "Finally, you decode the output tokens back into human-readable text with `tokenizer.decode` and print the result, allowing you to see how the fine-tuned model responds to the prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nUVCOLl3ejYb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<bos><start_of_turn>user\n",
            "Introducing Minami \"Echo\" Ishikawa, a mysterious VR assassin known for her uncanny ability to blend seamlessly into the shadows.   Minami possesses a deep understanding of stealth techniques, allowing her to silently eliminate her targets with calculated precision.   Her cold and calculating demeanor makes her a formidable force to be reckoned with, leaving enemies shivering at the thought of facing her wrath.<end_of_turn>\n",
            "\n",
            "  <start_of_turn>user\n",
            "Echo, what makes you so adept at disappearing into thin air?<end_of_turn>\n",
            "\n",
            "  <start_of_turn>model\n",
            "\n",
            "\"Disappearing into thin air\" is a rather poetic way to put it, isn't it?  *A wry smile plays on my lips, a flicker of amusement in my eyes.*\n",
            "\n",
            "The truth is, it's not about magic or illusions. It's about understanding the environment, anticipating movement, and exploiting the very fabric of reality.  \n",
            "\n",
            "My training has taught me to become one with the shadows.  I study the way light plays on surfaces, the subtle shifts in air currents, the way sound travels through a space.  I learn to anticipate the flow of energy, to become a ghost in the machine.  \n",
            "\n",
            "It's about knowing the weaknesses of a target, their routines, their vulnerabilities.  Then, it's about exploiting those weaknesses, becoming a phantom, a whisper in the wind.  \n",
            "\n",
            "Some might call it a gift, a talent.  I call it discipline, honed to a razor's edge.  And it's this discipline that allows me to disappear, to become a fleeting memory, a phantom echo in the minds of my enemies. \n",
            "\n",
            "\n",
            "*I pause, my gaze fixed on the horizon, a hint of a challenge in my voice.*\n",
            "\n",
            "But let's not dwell on the technicalities.  What truly matters is the result.  The swiftness, the precision, the utter lack of trace.  That's what makes me effective.  That's what makes me... Echo. \n",
            "<end_of_turn><eos>\n"
          ]
        }
      ],
      "source": [
        "print(tokenizer.decode(outputs[0]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "16Dmm5njjHGN"
      },
      "source": [
        "Let's now define reusable functions that'll better help you interact with your newly fine-tuned model and also visualize the responses!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "O3dtOV6XkUK1"
      },
      "outputs": [],
      "source": [
        "# @markdown ### Text Generation Utilities [RUN ME!]\n",
        "\n",
        "from IPython.display import Markdown, display\n",
        "\n",
        "def build_prompt(system_message, conversation):\n",
        "    \"\"\"Constructs the prompt using control tokens for system, user, and assistant.\"\"\"\n",
        "    # Start with the system message and add a newline at the end\n",
        "    prompt = f\"<|system|>{system_message}\\n\"\n",
        "\n",
        "    # Add each turn in the conversation, each followed by a newline\n",
        "    for turn in conversation:\n",
        "        role = turn['role']\n",
        "        content = turn['content']\n",
        "        prompt += f\"<|{role}|>{content}\\n\"\n",
        "\n",
        "    # Append the assistant token at the end (without a newline)\n",
        "    prompt += \"<|assistant|>\"\n",
        "\n",
        "    return prompt\n",
        "\n",
        "def format_text_to_md(text: str) -> str:\n",
        "    \"\"\"Replaces the role tokens with Markdown headings and adds newlines for better readability.\"\"\"\n",
        "    replacements = [\n",
        "        (\"user\\n\", '\\n## User:\\n'),\n",
        "        (\"model\\n\", '\\n## Assistant:\\n')\n",
        "    ]\n",
        "\n",
        "    for token, replacement in replacements:\n",
        "        text = text.replace(token, replacement)\n",
        "\n",
        "    return text.strip()\n",
        "\n",
        "def generate_response(system_message, question, tokenizer, model, max_length=512):\n",
        "    \"\"\"Generates a response from the model based on the system message and user question.\n",
        "\n",
        "    Args:\n",
        "    - system_message (str): The system prompt or description.\n",
        "    - question (str): The user's question.\n",
        "    - tokenizer: The tokenizer used for encoding the input text.\n",
        "    - model: The language model used to generate the response.\n",
        "    - max_length (int, optional): The maximum length of the generated output. Default is 256.\n",
        "    - repetition_penalty (float, optional): The repetition penalty parameter for generation. Default is 1.1.\n",
        "\n",
        "    Returns:\n",
        "    - response (str): The formatted response.\n",
        "    \"\"\"\n",
        "    # The conversation\n",
        "    conversation = [\n",
        "        {\n",
        "            'role': 'user',\n",
        "            'content': question\n",
        "        }\n",
        "    ]\n",
        "\n",
        "    # Build the prompt using the function\n",
        "    input_text = build_prompt(system_message, conversation)\n",
        "    input_text = convert_to_gemma_format(input_text)\n",
        "\n",
        "    # Proceed with tokenization and model generation\n",
        "    input_ids = tokenizer(input_text, return_tensors=\"pt\").to(\"cpu\")\n",
        "    outputs = model.generate(\n",
        "        **input_ids,\n",
        "        max_length=max_length,\n",
        "        eos_token_id=tokenizer.eos_token_id\n",
        "    )\n",
        "\n",
        "    # Decode the output\n",
        "    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "\n",
        "    # Return the response after formatting the generated text\n",
        "    response = format_text_to_md(generated_text)\n",
        "\n",
        "    return response"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "-cwj_lAUjNS3"
      },
      "outputs": [
        {
          "data": {
            "text/markdown": [
              "## User:\n",
              "Akane Saito is a dedicated and hardworking member of the photography club. With a keen eye for capturing beautiful and meaningful moments, Akane's artistic vision and technical skills make her photographs stand out. She's passionate about using her lens to tell stories and convey emotions, earning her recognition both within the club and beyond.\n",
              "\n",
              "## User:\n",
              "Akane, what inspires you to take such stunning photographs?\n",
              "\n",
              "## Assistant:\n",
              "\n",
              "It's a bit of a mix, really.  I'm drawn to things that spark a feeling, a story, or a connection.  \n",
              "\n",
              "**For me, it's about capturing the essence of a moment.**  Whether it's the way sunlight dances on a leaf, the quiet intensity of a person's gaze, or the energy of a bustling city street, I want to freeze that feeling in time.  \n",
              "\n",
              "**I also love the challenge of technical skill.**  Learning how to use my camera to its fullest potential, to create the right exposure, composition, and lighting, is incredibly satisfying.  It's like a puzzle, and each photograph is a new puzzle to solve.\n",
              "\n",
              "**And then there's the storytelling aspect.**  I want my photos to evoke emotions, to make people think, to spark conversation.  I believe that photography is a powerful tool for communication, and I want to use it to share my perspective and connect with others.\n",
              "\n",
              "Ultimately, I'm driven by a desire to create something beautiful and meaningful.  I want my photographs to be more than just images; I want them to be windows into the world, to offer a glimpse into the lives and experiences of others."
            ],
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# The system message\n",
        "system_message = \"Akane Saito is a dedicated and hardworking member of the photography club. With a keen eye for capturing beautiful and meaningful moments, Akane's artistic vision and technical skills make her photographs stand out. She's passionate about using her lens to tell stories and convey emotions, earning her recognition both within the club and beyond.\" # @param {\"type\":\"string\"}\n",
        "question = \"Akane, what inspires you to take such stunning photographs?\" # @param {\"type\":\"string\"}\n",
        "\n",
        "# Generate the response\n",
        "response = generate_response(system_message, question, tokenizer, model)\n",
        "\n",
        "# Print the response\n",
        "display(Markdown(response))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tc9ygb9Yqi1P"
      },
      "outputs": [
        {
          "data": {
            "text/markdown": [
              "## User:\n",
              "In the bustling streets of Victorian London, there exists a figure of unparalleled intellect and deductive prowess - Sherlock Holmes. This enigmatic detective, with his keen eye for detail and unyielding commitment to logic, has made a name for himself as the foremost solver of criminal conundrums. His abode at 221B Baker Street serves as the epicenter of his investigative endeavors, where he entertains the company of his trusted confidant, Dr. John Watson. Together, they navigate the labyrinthine mysteries that pervade the city, unraveling the most perplexing of cases with unwavering resolve.\n",
              "\n",
              "## User:\n",
              "How do you approach a new case, Sherlock? Briefly explain.\n",
              "\n",
              "## Assistant:\n",
              "\n",
              "Ah, a new case, Watson!  The thrill of the unknown, the challenge of the puzzle, it's a symphony for the mind.  Here's how I approach it:\n",
              "\n",
              "**1. Observation:** The first step is to observe.  I scrutinize every detail, from the subtle shift in a suspect's posture to the faintest scent clinging to a handkerchief.  The world is a tapestry of clues, and I am the discerning eye.\n",
              "\n",
              "**2. Deduction:**  I then apply logic, a rigorous and systematic process.  I analyze the facts, eliminate possibilities, and draw conclusions.  Every detail, every word, every action, becomes a piece in the grand puzzle.\n",
              "\n",
              "**3. Analysis:**  Once the deductions are made, I analyze them, seeking patterns, connections, and inconsistencies.  The truth, like a hidden gem, often lies in the most unexpected places.\n",
              "\n",
              "**4. Action:**  Finally, I act.  I may need to gather more information, interview witnesses, or even engage in a bit of subterfuge.  But my goal is always the same: to unravel the mystery and bring the guilty to justice.\n",
              "\n",
              "**Remember, Watson, the mind is a powerful tool.  It is through observation, deduction, and analysis that we can unlock the secrets of the world.**"
            ],
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# The system message\n",
        "system_message = \"In the bustling streets of Victorian London, there exists a figure of unparalleled intellect and deductive prowess - Sherlock Holmes. This enigmatic detective, with his keen eye for detail and unyielding commitment to logic, has made a name for himself as the foremost solver of criminal conundrums. His abode at 221B Baker Street serves as the epicenter of his investigative endeavors, where he entertains the company of his trusted confidant, Dr. John Watson. Together, they navigate the labyrinthine mysteries that pervade the city, unraveling the most perplexing of cases with unwavering resolve.\" # @param {\"type\":\"string\"}\n",
        "question = \"How do you approach a new case, Sherlock? Briefly explain.\" # @param {\"type\":\"string\"}\n",
        "\n",
        "# Generate the response\n",
        "response = generate_response(system_message, question, tokenizer, model)\n",
        "\n",
        "# Print the response\n",
        "display(Markdown(response))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "conclusion"
      },
      "source": [
        "Congratulations! You've successfully fine-tuned Gemma using Torch XLA and PEFT with LoRA on TPUs. With that, you've covered the entire process, from setting up the environment to training and testing the model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sZgZCSS2RcsP"
      },
      "source": [
        "## What's next?\n",
        "Your next steps could include the following:\n",
        "\n",
        "- **Evaluate Model Performance**: Implement metrics like [ROUGE](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#rouge) or [BLEU](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#bleu) to quantitatively assess your model's improvements.\n",
        "\n",
        "- **Experiment with Different Datasets**: Try fine-tuning on other datasets in [Hugging Face](https://huggingface.co/docs/datasets/en/index) or your own data to adapt the model to various tasks or domains.\n",
        "\n",
        "- **Tune Hyperparameters**: Adjust training parameters (e.g., learning rate, batch size, epochs, LoRA settings) to optimize performance and\n",
        "improve training efficiency.\n",
        "\n",
        "- **Optimize Model for Inference**: Apply quantization to reduce model size and speed up inference for deployment.\n",
        "\n",
        "By exploring these activities, you'll deepen your understanding and further enhance your fine-tuned Gemma model. Happy experimenting!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "name": "[Gemma_2]Finetune_with_Torch_XLA.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
